*-------------------------------------------------------------------------------
* ccn_ESens_het
*
* Computes quantities related to the censored expectation of a variable that
* piles up at zero, under the assumption selected by model_type. Results are
* returned in r(); the caller's data are preserved and restored.
*
* Positional arguments (bound by -args-):
*   1 model_type  Ecounterfactual | bilog | nohighpeaks | het_uniform |
*                 het_tobit | symmetric   (any other value returns nothing)
*   2 k           suffix selecting variables X_`k' and dumX_`k'_*
*   3 inv_var     censored variable, renamed to I; I==0 marks censoring
*   4 dep_var     outcome, renamed to S (regressed only in Ecounterfactual)
*   5 se_type     accepted but not used by any branch below
*   6 controls    extra regressors for Ecounterfactual; pass a multi-variable
*                 list as ONE quoted argument
*   7 swap_type   symmetric only: tobit | het_tobit replacement used for
*                 groups flagged as unusable by the symmetric construction
*   8 E           Ecounterfactual only: multiplies the censoring indicator in
*                 reg_term = E*cens_ind + I
*
* r() scalars returned, by model_type:
*   Ecounterfactual : DELTA BETA BETA_se BETA_A BETA_A_se BETA_L BETA_AL
*   bilog           : E_bilog_lb E_bilog_ub
*   nohighpeaks     : E_NoHighPeaks
*   het_uniform     : E_uniform
*   het_tobit       : E_het_tobit
*   symmetric       : CLUS_SWITCH E_symmetric
*
* Other requirements visible in the code: variable A (Ecounterfactual),
* X_`k' (het_uniform, het_tobit, symmetric), dumX_`k'_* (Ecounterfactual, and
* symmetric with swap_type tobit), global $band (bilog, nohighpeaks).
*-------------------------------------------------------------------------------
capture program drop ccn_ESens_het
program define ccn_ESens_het, rclass
	* -args- copies each positional token verbatim into a named local
	* (no string-expression evaluation, so no quoting is needed here).
	args model_type k inv_var dep_var se_type controls swap_type E

	preserve
	* Work under fixed names; -capture- lets a rename fail silently, e.g. when
	* the variable is already called I or S.
	capture qui rename `inv_var' I
	capture qui rename `dep_var' S
	qui gen cens_ind = (I==0)

	* Imod = I with the I==2080 observations dropped (set to missing). Imod is
	* used for I when computing every imr / censored-expectation term.
	* cens_indmod is missing wherever Imod is missing, so censored shares are
	* taken over the same observations as the statistics computed on Imod
	* (e.g. -cumul Imod-, which skips missing values). Otherwise the dropped
	* observations would be counted as uncensored.
	qui gen Imod = I
	qui replace Imod = . if Imod == 2080
	qui gen cens_indmod = (Imod == 0) if !missing(Imod)

	if ("`model_type'" == "Ecounterfactual") {
		* reg_term is I for uncensored observations and `E' where I==0.
		qui gen reg_term = `E'*cens_ind + I
		qui gen reg_term_A = A*reg_term
		qui gen I_A = A*I
		qui gen I_L = I*I
		qui gen I_AL = A*I*I
		* NOTE(review): se_type is not passed to -reg- (e.g. as vce()); confirm
		* that default OLS standard errors are intended.
		qui reg S I I_A I_L I_AL dumX_`k'_* `controls' reg_term reg_term_A, noconstant
		return scalar DELTA     = _b[reg_term]
		return scalar BETA      = _b[I]
		return scalar BETA_se   = _se[I]
		return scalar BETA_A    = _b[I_A]
		return scalar BETA_A_se = _se[I_A]
		return scalar BETA_L    = _b[I_L]
		return scalar BETA_AL   = _b[I_AL]
	}

	if inlist("`model_type'", "bilog", "nohighpeaks") {
		* Both variants use F0 (share of observations with I==0) and f0lim, the
		* slope of the empirical CDF of I regressed on I over 0 < I < $band
		* (presumably an estimate of the density of I just above zero).
		if (`"$band"' == "") {
			di as error "ccn_ESens_het: global band must be set for `model_type'"
			exit 198
		}
		cumul I, generate(CDF_I)
		sum cens_ind
		local F0 = r(mean)
		local band = $band
		reg CDF_I I if I>0 & I<`band'
		local f0lim = _b[I]
		if ("`model_type'" == "bilog") {
			return scalar E_bilog_lb = -(`F0'/`f0lim')
			return scalar E_bilog_ub = ((1-`F0')/`f0lim')*(1 + (log(1-`F0')/`F0'))
		}
		else {
			return scalar E_NoHighPeaks = -(1/2)*(`F0'/`f0lim')
		}
	}

	if ("`model_type'" == "het_uniform") {
		* Per group of X_`k': imr_term = -mean(Imod | Imod>0) * p/(1-p), where p
		* is the group's censored share (mean of cens_indmod).
		qui gen imr_term = .
		levelsof X_`k', local(levels)
		foreach i of local levels {
			qui sum Imod if Imod>0 & X_`k' == `i'
			local pos_mean = r(mean)
			qui sum cens_indmod if X_`k' == `i'
			local bunching = r(mean)
			qui replace imr_term = -`pos_mean'*(`bunching'/(1-`bunching')) if X_`k' == `i'
		}
		sum imr_term if Imod==0
		return scalar E_uniform = r(mean)
	}

	if ("`model_type'" == "het_tobit") {
		* Separate constant-only Tobit of Imod (left-censored at 0) per group of
		* X_`k'; imr_term holds the group's predict e(.,0), i.e. E(y*|y*<0).
		* Groups where -tobit- fails are dropped from the (preserved) data.
		qui gen imr_term = .
		levelsof X_`k', local(levels)
		foreach i of local levels {
			qui capture tobit Imod if X_`k' == `i', ll(0)
			if _rc==0 {
				qui matrix coeff_mat_`i' = e(b)
				qui matrix var_mat_`i' = e(V)
				qui predict imr_temp, e(.,0)
				qui replace imr_term = imr_temp if X_`k' == `i'
				qui drop imr_temp
				* sigma_`i' is not used below.
				local sigma_`i' = coeff_mat_`i'[1,2]^0.5
			}
			else {
				drop if X_`k'==`i'
			}
		}
		* Per-group means of imr_term stored in locals (not used further here).
		levelsof X_`k', local(levels)
		foreach i of local levels {
			qui sum imr_term if X_`k' == `i'
			local imr_term_`i' = r(mean)
		}
		sum imr_term if Imod==0
		return scalar E_het_tobit = r(mean)
	}

	if ("`model_type'" == "symmetric") {
		levelsof X_`k', local(levels)

		* switch_i = 1 flags group i as censored in at least half of its
		* observations (median of cens_indmod equals 1).
		foreach i of local levels {
			local switch_`i' = 0
			qui sum cens_indmod if X_`k' == `i', d
			if (r(p50) == 1) {
				local switch_`i' = 1
			}
		}

		* Censored expectation per group, computed assuming symmetry; not valid
		* for groups censored more than 50% (flagged above).
		gen cens_exp = .
		foreach i of local levels {
			* Step 1: hatF_0 = share of group i with Imod == 0.
			qui sum cens_indmod if X_`k' == `i'
			local hatF_0_`i' = r(mean)
			local op_hatF_0_`i' = 1 - `hatF_0_`i''

			* Step 2: step2 = largest Imod in group i whose empirical CDF is
			* <= 1 - hatF_0; 0 when even the smallest CDF value exceeds it.
			qui cumul Imod if X_`k' == `i', gen(Icumul_`i') equal
			qui sum Icumul_`i' if X_`k' == `i'
			local Icumul_min = r(min)
			if `op_hatF_0_`i'' > `Icumul_min' {
				qui sum Imod if X_`k' == `i' & Icumul_`i' <= `op_hatF_0_`i''
				local step2_`i' = r(max)
			}
			else {
				local step2_`i' = 0
			}

			* Step 3: mean of Imod in group i over Imod >= step2.
			qui sum Imod if Imod >= `step2_`i'' & X_`k' == `i'
			local step3_`i' = r(mean)

			* Step 4: censored expectation for group i = step2 - step3.
			local cens_expec_`i' = `step2_`i'' - `step3_`i''
			qui replace cens_exp = `cens_expec_`i'' if X_`k' == `i'
		}

		* Also flag groups where -qreg- at quantile 1 - hatF_0 fails.
		foreach i of local levels {
			qui capture qreg Imod if X_`k' == `i', q(`op_hatF_0_`i'')
			if (_rc != 0) {
				local switch_`i' = 1
				local qreg_switch_`i' = 1
			}
		}

		local tot_switch = 0
		foreach i of local levels {
			local tot_switch = `tot_switch' + `switch_`i''
		}
		return scalar CLUS_SWITCH = `tot_switch'

		if (`tot_switch' > 0) {
			if ("`swap_type'" == "tobit") {
				* One pooled Tobit with group dummies; its predict e(.,0) replaces
				* cens_exp for flagged groups. (The error-variance column of e(b)
				* is not extracted: its position depends on how many
				* dumX_`k'_* dummies exist.)
				qui tobit Imod dumX_`k'_*, ll(0) noconstant
				matrix coeff_mat_tob = e(b)
				matrix var_mat_tob = e(V)
				qui predict imr_term_tob, e(.,0)
				foreach i of local levels {
					* Use the full name: no variable imr_term exists in this
					* branch, and relying on varabbrev is fragile.
					qui sum imr_term_tob if X_`k' == `i'
					local imr_term_`i' = r(mean)
					if (`switch_`i'' == 1) {
						local cens_exp_`i' = `imr_term_`i''
						replace cens_exp = `imr_term_`i'' if X_`k' == `i'
					}
				}
			}
			if ("`swap_type'" == "het_tobit") {
				* Per-group constant-only Tobit, run only for flagged groups;
				* groups where -tobit- fails are dropped.
				qui gen imr_term = .
				foreach i of local levels {
					if (`switch_`i'' == 1) {
						qui capture tobit Imod if X_`k' == `i', ll(0)
						if _rc==0 {
							qui matrix coeff_mat_`i' = e(b)
							qui matrix var_mat_`i' = e(V)
							qui predict imr_temp, e(.,0)
							qui replace imr_term = imr_temp if X_`k' == `i'
							qui drop imr_temp
							local sigma_`i' = coeff_mat_`i'[1,2]^0.5
						}
						else {
							drop if X_`k'==`i'
						}
						* Assign the group's Tobit censored expectation to cens_exp.
						qui sum imr_term if X_`k' == `i'
						local imr_term_`i' = r(mean)
						replace cens_exp = `imr_term_`i'' if X_`k' == `i'
					}
				}
			}
		}
		sum cens_exp if Imod==0
		return scalar E_symmetric = r(mean)
	}
	restore
end
